mmdetection source code walkthrough
Contents
Introduction
1. Test code
2. mmdetection/mmdet/apis
2.1 mmdetection/mmdet/apis/inference.py
2.1.1 init_detector
2.1.2 inference_detector
3. build_detector()
Introduction
GitHub: https://github.com/open-mmlab/mmdetection. Pretrained weights for each model can be downloaded from model_zoo.md. Official mmdetection documentation and tutorials: https://mmdetection.readthedocs.io/en/latest/ (highly recommended).
1. Test code
import mmcv
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# test a single image and show the results
img = 'demo.jpg' # or img = mmcv.imread(img), which will only load it once
result = inference_detector(model, img)
# visualize the results in a new window
show_result_pyplot(model, img, result)
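For a box-only detector such as the Faster R-CNN model above, the `result` returned by `inference_detector` in mmdetection 2.x is a list with one NumPy array per class, each of shape (n, 5) holding `[x1, y1, x2, y2, score]`. The sketch below assumes that format and shows how the detections could be filtered by score without drawing them. The helper `filter_detections`, the class names and the boxes are hypothetical, made up for illustration only.

```python
import numpy as np


def filter_detections(result, class_names, score_thr=0.3):
    """Hypothetical helper: flatten a per-class bbox result into
    (class_name, [x1, y1, x2, y2], score) tuples with score >= score_thr.

    Assumes `result` is a list holding one (n, 5) array per class, the
    format used for box-only detectors in mmdetection 2.x.
    """
    detections = []
    for class_id, bboxes in enumerate(result):
        for row in bboxes:  # each row is [x1, y1, x2, y2, score]
            score = float(row[4])
            if score >= score_thr:
                box = [float(v) for v in row[:4]]
                detections.append((class_names[class_id], box, score))
    # Highest-scoring detections first
    detections.sort(key=lambda d: d[2], reverse=True)
    return detections


# Made-up result for a 2-class model (not real model output)
fake_result = [
    np.array([[10, 20, 110, 220, 0.95], [5, 5, 50, 50, 0.10]], dtype=np.float32),
    np.zeros((0, 5), dtype=np.float32),  # no boxes for the second class
]
for name, box, score in filter_detections(fake_result, ['person', 'car']):
    print(name, box, round(score, 2))
```

With a real model, `model.CLASSES` (which `init_detector` sets in mmdetection 2.x) could be passed as `class_names`. Mask-producing models return a `(bbox_result, segm_result)` pair instead, so only the first element would be passed in.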
Create a new project; its directory structure is as follows:
[Figure: project directory structure]
2. mmdetection/mmdet/apis
This folder contains three files:
inference.py: initializes the model, runs forward inference, reads images, displays detection results, etc.
train.py: used for training.
test.py: used for testing.
2.1 mmdetection/mmdet/apis/inference.py
The purpose, input parameters and output of each function are described directly in its docstring. A few examples are given here.
2.1.1 init_detector
from mmdet.apis import init_detector, inference_detector, show_result_pyplot

config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
# build the model from a config file and a checkpoint file
# device can be 'cpu', 'cuda:0', 'cuda:1', etc.
model = init_detector(config_file, checkpoint_file, device='cuda:0')
2.1.2 inference_detector
# imports this excerpt needs (assuming an mmcv 1.x install; older mmdetection 2.x
# releases took RoIAlign/RoIPool from mmdet.ops instead of mmcv.ops)
import warnings

import torch
from mmcv.ops import RoIAlign, RoIPool
from mmcv.parallel import collate, scatter

from mmdet.datasets.pipelines import Compose
# LoadImage is a small pipeline class defined earlier in inference.py itself


def inference_detector(model, img):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        If imgs is a str, a generator will be returned, otherwise return the
        detection results directly.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline: replace the first (file-loading) step with LoadImage
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        # Use torchvision ops for CPU mode instead
        for m in model.modules():
            if isinstance(m, (RoIPool, RoIAlign)):
                if not m.aligned:
                    # aligned=False is not implemented on CPU
                    # set use_torchvision on-the-fly
                    m.use_torchvision = True
        warnings.warn('We set use_torchvision=True in CPU mode.')
        # just get the actual data from DataContainer
        data['img_metas'] = data['img_metas'][0].data

    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result
Although the docstring mentions imgs, lists and a generator, this version handles a single image: img may be a file path or an already-loaded array, and the detection result is returned directly.
3. build_detector()
model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
This call builds the neural network from the config file. Taking configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py as an example, cfg.model, cfg.train_cfg and cfg.test_cfg are all dicts, each corresponding to the section of the same name in the config file.
DETECTORS = Registry('detector')


def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build detector."""
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
You can think of the registry mechanism as a large factory made up of small warehouses, one Registry instance per warehouse, 7 in total. The warehouses are created as soon as train.py imports the modules at startup, and each one is stocked with all kinds of goods (classes):
from mmcv.utils import Registry

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
For example, detectors/__init__.py shows the goods stocked in the DETECTORS warehouse. Importing this file executes every module it lists, and that is what runs their @DETECTORS.register_module() decorators:
# detectors/__init__.py
from .atss import ATSS
from .base import BaseDetector
from .cascade_rcnn import CascadeRCNN
from .fast_rcnn import FastRCNN
from .faster_rcnn import FasterRCNN
from .fcos import FCOS
from .fovea import FOVEA
from .fsaf import FSAF
from .gfl import GFL
from .grid_rcnn import GridRCNN
from .htc import HybridTaskCascade
from .mask_rcnn import MaskRCNN
from .mask_scoring_rcnn import MaskScoringRCNN
from .nasfcos import NASFCOS
from .point_rend import PointRend
from .reppoints_detector import RepPointsDetector
from .retinanet import RetinaNet
from .rpn import RPN
from .single_stage import SingleStageDetector
from .two_stage import TwoStageDetector
Here cfg is a dict rather than a list, so build() goes straight to build_from_cfg():
from mmcv.utils import build_from_cfg
from torch import nn


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        # a list of configs is built into an nn.Sequential of modules
        modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)
For example, in fcos.py the decorator @DETECTORS.register_module() puts the ready-made FCOS class into the DETECTORS warehouse:
from ..builder import DETECTORS
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class FCOS(SingleStageDetector):
    """Implementation of `FCOS `_"""

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
                                   test_cfg, pretrained)
The docstring of register_module() shows how it is used: each item's name is fixed at registration time and defaults to the class name, such as FCOS or ResNet, unless a name is passed explicitly (name='mnet' below). So to create your own detector, write a class, treat it as a new item, and put it into the appropriate warehouse, i.e. register it.
def register_module(self, name=None, force=False, module=None):
    """Register a module.

    A record will be added to `self._module_dict`, whose key is the class
    name or the specified name, and value is the class itself.
    It can be used as a decorator or a normal function.

    Example:
        >>> backbones = Registry('backbone')
        >>> @backbones.register_module()
        >>> class ResNet:
        >>>     pass

        >>> backbones = Registry('backbone')
        >>> @backbones.register_module(name='mnet')
        >>> class MobileNet:
        >>>     pass

        >>> backbones = Registry('backbone')
        >>> class ResNet:
        >>>     pass
        >>> backbones.register_module(ResNet)
    """
    ...  # method body omitted in this excerpt
In build_from_cfg(), obj_type is 'FCOS':
import inspect

from mmcv.utils import is_str


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict."""
    args = cfg.copy()
    obj_type = args.pop('type')  # 'FCOS'
    if is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
registry.get('FCOS') returns the FCOS class registered by the decorator. default_args then adds train_cfg and test_cfg through setdefault, which never overwrites a key already present in the config, and finally obj_cls(**args) instantiates FCOS with the remaining keys of cfg.model (backbone, neck, bbox_head, ...).
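To see the whole mechanism in one place, here is a minimal, self-contained imitation of the registry-plus-config pattern described above. It is not mmcv's `Registry`: the class `MiniRegistry`, the toy classes `ToyResNet` and `ToyDetector`, and the config values are hypothetical. It only mirrors the steps discussed: registering a class under its name with a decorator, looking the class up from the config's `type` key, and filling in `train_cfg`/`test_cfg` through `default_args` with `setdefault`.

```python
import inspect


class MiniRegistry:
    """Hypothetical, simplified stand-in for mmcv's Registry."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def get(self, key):
        return self._module_dict.get(key)

    def register_module(self, name=None, force=False):
        # Used as a decorator: @REG.register_module() or @REG.register_module(name='x')
        def _register(cls):
            key = name or cls.__name__  # default key is the class name
            if not force and key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls
        return _register


def build_from_cfg(cfg, registry, default_args=None):
    """Build an object from a config dict whose 'type' names a registered class."""
    args = dict(cfg)  # shallow copy so the caller's config keeps its 'type'
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')
    # default_args only fill keys the config did not set
    for key, value in (default_args or {}).items():
        args.setdefault(key, value)
    return obj_cls(**args)


BACKBONES = MiniRegistry('backbone')
DETECTORS = MiniRegistry('detector')


@BACKBONES.register_module()
class ToyResNet:
    def __init__(self, depth):
        self.depth = depth


@DETECTORS.register_module()
class ToyDetector:
    def __init__(self, backbone, train_cfg=None, test_cfg=None):
        # Nested config: the detector builds its own backbone from a sub-dict
        self.backbone = build_from_cfg(backbone, BACKBONES)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg


model_cfg = dict(type='ToyDetector', backbone=dict(type='ToyResNet', depth=50))
model = build_from_cfg(model_cfg, DETECTORS,
                       dict(train_cfg=None, test_cfg=dict(score_thr=0.05)))
print(type(model).__name__, model.backbone.depth, model.test_cfg)
```

The nested build inside `ToyDetector.__init__` is the design point: the top-level call only needs to know the detector's registry, and each component resolves its own sub-config. In mmdetection 2.x, detector base classes such as SingleStageDetector similarly build their backbone, neck and head from sub-configs in `__init__`.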